#!/usr/bin/env python
# _*_coding:utf-8_*_

"""
@Author:  zhaojianghua
@Software: PyCharm
"""

import random
from PIL import Image, ImageDraw
from ivkcore import common

import cv2
import numpy as np


class MazeEnv(object):
    """Grid-world maze environment.

    Cell codes in ``self.maze``: 0 = free, 1 = goal (reward +1),
    2 = obstacle (reward -1). The agent position is ``self.state``,
    a mutable ``[row, col]`` list; stepping off the grid ends the
    episode with reward -1 and state ``[-1, -1]``.
    """

    def __init__(self, maze=None):
        """
        :param maze: optional rectangular grid of cell codes; defaults to
                     the original 5x5 layout (goal at (4, 4), obstacles at
                     (2, 2) and (2, 3)).
        """
        if maze is None:
            maze = [[0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 2, 2, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 1]]
        self.maze = maze
        self.n_rows = len(maze)
        self.n_cols = len(maze[0])
        self.state = [0, 0]
        self.is_done = False

    def step(self, action):
        """Apply one move and report the outcome.

        :param action: 0, 1, 2, 3 -> left, right, up, down.
            (0/1 change the column, 2/3 change the row; the previous
            docstring listed the directions in the wrong order.)
        :return: (state, reward, is_done); ``state`` is the internal
            mutable list, so callers should copy/tuple it if they keep it.
        :raises ValueError: on an action outside 0..3.
        """
        if self.is_done:
            # Stepping a finished episode is a no-op with zero reward.
            print("Warning: env is done, reset it!!!")
            return self.state, 0, self.is_done

        if action == 0:
            self.state[1] -= 1
        elif action == 1:
            self.state[1] += 1
        elif action == 2:
            self.state[0] -= 1
        elif action == 3:
            self.state[0] += 1
        else:
            raise ValueError("Unknown action:", action)
        row, col = self.state
        if not (0 <= row < self.n_rows and 0 <= col < self.n_cols):
            # Walked off the board: terminal sink state with penalty.
            self.state = [-1, -1]
            self.is_done = True
            return self.state, -1, self.is_done
        cell = self.maze[row][col]
        if cell == 1:
            self.is_done = True
            return self.state, 1, self.is_done
        elif cell == 2:
            self.is_done = True
            return self.state, -1, self.is_done
        else:
            return self.state, 0, False

    def reset(self):
        """Return the agent to the start cell and clear the done flag."""
        self.state = [0, 0]
        self.is_done = False

    def random_state(self):
        """Teleport the agent to a uniformly random *free* (code 0) cell."""
        while True:
            a = random.randint(0, self.n_rows - 1)
            b = random.randint(0, self.n_cols - 1)
            if self.maze[a][b] == 0:
                self.state = [a, b]
                break

    def render(self):
        """Draw the maze as a PIL image (100px cells) and return it.

        Goal cells are red, obstacles blue; the agent is a green "P".
        """
        cell = 100
        panel = Image.new("RGB", (self.n_cols * cell, self.n_rows * cell), (255, 255, 120))
        draw = ImageDraw.Draw(panel)
        line_ofs = 1
        line_color = (10, 10, 10)
        # Grid lines (one more than the number of rows/cols).
        for i in range(0, max(self.n_rows, self.n_cols) + 1):
            draw.line([(0, i * cell - line_ofs), (self.n_cols * cell, i * cell - line_ofs)],
                      fill=line_color, width=4)
            draw.line([(i * cell - line_ofs, 0), (i * cell - line_ofs, self.n_rows * cell)],
                      fill=line_color, width=4)
        obj_ofs = 5
        reward_color = (245, 25, 25)
        obstacle_color = (25, 25, 245)
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                if self.maze[i][j] == 1:
                    y1, x1 = i * cell + obj_ofs, j * cell + obj_ofs
                    x2, y2 = x1 + 90, y1 + 90
                    polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2], [x1, y1]]
                    draw.polygon([tuple(x) for x in polygon], fill=reward_color)
                elif self.maze[i][j] == 2:
                    y1, x1 = i * cell + obj_ofs, j * cell + obj_ofs
                    x2, y2 = x1 + cell - 2 * obj_ofs, y1 + cell - 2 * obj_ofs
                    polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2], [x1, y1]]
                    draw.polygon([tuple(x) for x in polygon], fill=obstacle_color)
        if 0 <= self.state[0] < self.n_rows and \
                0 <= self.state[1] < self.n_cols:
            y1, x1 = self.state[0] * cell + 20, self.state[1] * cell + 35
            draw.text((x1, y1), "P", fill=(25, 245, 25), font=common.get_default_font(48))
        del draw
        return panel


class QLearningAgent(object):
    """Tabular Q-learning agent for the 5x5 maze.

    The Q-table maps every (row, col) state -- plus the (-1, -1)
    off-grid sink state -- to a dict of action -> value for the four
    moves 0..3.
    """

    def __init__(self):
        grid = 5
        states = [(r, c) for r in range(grid) for c in range(grid)]
        states.append((-1, -1))  # terminal sink for leaving the grid
        self.q_table = {s: {a: 0 for a in range(4)} for s in states}
        self.lr = 0.01               # learning rate (alpha)
        self.beta = 0.95             # discount factor
        self.explore_epsilon = 0.2   # exploration probability

    def action(self, observe, explore=True):
        """Pick an action epsilon-greedily for the observed state.

        With probability ``explore_epsilon`` (when ``explore`` is on) a
        uniformly random action is returned; otherwise the greedy action,
        with ties broken at random via a shuffle.
        """
        if explore and random.random() < self.explore_epsilon:
            return random.randint(0, 3)
        q_row = self.q_table[observe]
        candidates = list(q_row.keys())
        random.shuffle(candidates)
        return max(candidates, key=q_row.get)

    def update(self, action, observe, reward, observe_next):
        """One Q-learning backup for the transition (s, a, r, s').

        Q(s_t, a_t) = Q(s_t, a_t) + alpha * (beta * Max_a(Q(s_t+1, a)) + R - Q(s_t, a_t))
                    = (1 - alpha) * Q(s_t, a_t) + alpha * (beta * Max_a(Q(s_t+1, a)) + R)
        :param action: action taken in state ``observe``
        :param observe: state (tuple) before the step
        :param reward: reward received for the step
        :param observe_next: state (tuple) after the step
        :return: None (mutates the Q-table in place)
        """
        target = self.beta * max(self.q_table[observe_next].values()) + reward
        old_value = self.q_table[observe][action]
        self.q_table[observe][action] = (1 - self.lr) * old_value + self.lr * target


def train(env, agent, episodes=10000):
    """Train the agent with tabular Q-learning over full episodes.

    Each episode starts from a random free cell (exploring starts) and
    runs until the environment reports termination.

    :param env: environment exposing ``reset()``, ``random_state()``,
        ``step(action) -> (state, reward, is_done)`` and a ``state`` list.
    :param agent: agent exposing ``action(state)`` and
        ``update(action, state, reward, next_state)``.
    :param episodes: number of training episodes; the default preserves
        the original hard-coded 10000.
    :return: None (the agent's Q-table is updated in place).
    """
    for _ in range(episodes):
        env.reset()
        env.random_state()  # exploring starts: random free initial cell
        while True:
            state = tuple(env.state)
            act = agent.action(state)
            state_nxt, reward, is_done = env.step(act)
            agent.update(act, state, reward, tuple(state_nxt))
            if is_done:
                break


if __name__ == "__main__":
    env = MazeEnv()
    agent = QLearningAgent()
    train(env, agent)

    print(agent.q_table)

    import time

    # Replay a single greedy (no-exploration) episode from the start
    # cell, popping up one image window per step.
    env.reset()
    frame = 0
    while True:
        frame += 1
        env.render().show(str(frame))
        time.sleep(2)
        state = tuple(env.state)
        act = agent.action(state, explore=False)
        observation, reward, is_done = env.step(act)
        if is_done:
            # Show the terminal position before finishing.
            frame += 1
            env.render().show(str(frame))
            break